// Copyright 2014 Google Inc. All Rights Reserved.

#include "common.h"
#include "ChannelManager.h"
#include "Controller.h"
#include "MessageRouter.h"
#include "ProtocolEndpointBase.h"

// Attaches the router to |channelManager| and resets both lookup tables:
// every channel is marked unmapped (SENTINEL_SERVICE_ID) and every service
// slot is emptied. Always returns true.
bool MessageRouter::init(ChannelManager* channelManager) {
    const size_t serviceSlots = sizeof(mServiceMap) / sizeof(mServiceMap[0]);
    for (size_t slot = 0; slot < serviceSlots; ++slot) {
        mServiceMap[slot] = NULL;
    }
    // Byte-wise fill; the table's entries are compared against the 8-bit sentinel.
    memset(mChannelServiceIdMap, SENTINEL_SERVICE_ID, sizeof(mChannelServiceIdMap));
    mChannelManager = channelManager;
    return true;
}

void MessageRouter::shutdown() {
    mChannelManager = NULL;
}

void MessageRouter::setupMapping(uint8_t serviceId, uint8_t channelId) {
    mChannelServiceIdMap[channelId] = serviceId;
}

int MessageRouter::handleChannelOpenReq(uint8_t channelId, const ChannelOpenRequest& req) {
    if (mChannelServiceIdMap[channelId] != SENTINEL_SERVICE_ID) {
        return STATUS_INVALID_CHANNEL;
    }

    uint8_t serviceId = (uint8_t) req.service_id();
    if (serviceId >= SENTINEL_SERVICE_ID) {
        return STATUS_INVALID_SERVICE;
    }
    ProtocolEndpointBase* service = mServiceMap[serviceId];
    if (service == NULL) {
        return STATUS_INVALID_SERVICE;
    }
    if (!service->mayOpenChannel(channelId)) {
        return STATUS_BUSY;
    }

    setupMapping(serviceId, channelId);
    return STATUS_SUCCESS;
}

void MessageRouter::notifyChannelOpened(uint8_t channelId, const ChannelOpenRequest& req) {
    uint8_t serviceId = (uint8_t) req.service_id();
    if (serviceId >= SENTINEL_SERVICE_ID) {
        LOG("Failed to notify channel open channel:%d service:%d", channelId, req.service_id());
        return;
    }
    ProtocolEndpointBase* service = mServiceMap[serviceId];
    if (service != NULL) {
        service->onChannelOpened(channelId);
    }
}

int MessageRouter::sendChannelOpenResp(uint8_t channelId, int32_t status) {
    ChannelOpenResponse resp;
    resp.set_status(static_cast<MessageStatus>(status));
    IoBuffer buf;
    marshallProto(MESSAGE_CHANNEL_OPEN_RESPONSE, resp, &buf);
    return mChannelManager->queueOutgoing(channelId, true, buf.raw(), buf.size());
}

int MessageRouter::handleChannelCloseNotif(uint8_t channelId,
                                           const ChannelCloseNotification& notification){
    if (mChannelServiceIdMap[channelId] == SENTINEL_SERVICE_ID) {
        return STATUS_INVALID_CHANNEL;
    }

    ProtocolEndpointBase* service = mServiceMap[mChannelServiceIdMap[channelId]];
    if (service == NULL) {
        return STATUS_INVALID_SERVICE;
    }
    if (!service->onChannelClosed(channelId)) {
        return STATUS_INTERNAL_ERROR;
    }
    mChannelServiceIdMap[channelId] = SENTINEL_SERVICE_ID;
    return STATUS_SUCCESS;
}

// Logs a warning and queues a bare two-byte MESSAGE_UNEXPECTED_MESSAGE type
// (written with WRITE_BE16, no payload) on |channelId|.
// Returns the result of ChannelManager::queueOutgoing.
int MessageRouter::sendUnexpectedMessage(uint8_t channelId) {
    LOGW("Sending unexpected message on channel %d", channelId);
    uint16_t header;
    WRITE_BE16(&header, MESSAGE_UNEXPECTED_MESSAGE);
    return mChannelManager->queueOutgoing(channelId, false, &header, sizeof(header));
}

int MessageRouter::routeChannelControlMsg(const shared_ptr<Frame>& frame,
                                          void* message, size_t len) {
    uint16_t type = extractType((uint8_t*)message);
    int status = STATUS_INVALID_CHANNEL;
    uint8_t* ptr = (uint8_t*) message + sizeof(uint16_t);
    len -= sizeof(uint16_t);

    switch (type) {
    case MESSAGE_CHANNEL_OPEN_REQUEST: {
        ChannelOpenRequest req;
        if (!PARSE_PROTO(req, ptr, len)) {
            break;
        }
        status = handleChannelOpenReq(frame->channelId, req);
        if (status != STATUS_FRAMING_ERROR) {
            if (status == STATUS_SUCCESS) {
                status = mChannelManager->registerChannel(frame->channelId, req.priority());
            }
            sendChannelOpenResp(frame->channelId, status);
            // Must be after channel open response is sent.
            if (status == STATUS_SUCCESS) {
                notifyChannelOpened(frame->channelId, req);
            }
        }
        break;
    }

    case MESSAGE_CHANNEL_CLOSE_NOTIFICATION: {
        ChannelCloseNotification notification;
        if (!PARSE_PROTO(notification, ptr, len)) {
            break;
        }
        status = handleChannelCloseNotif(frame->channelId, notification);
        if (status != STATUS_FRAMING_ERROR) {
            if (status == STATUS_SUCCESS) {
                status = mChannelManager->unregisterChannel(frame->channelId);
            }
        }
        break;
    }

    default:
        status = sendUnexpectedMessage(frame->channelId);
        assert(status == STATUS_SUCCESS);
    }

    return status;
}

// Decodes the message type from the first two bytes of |message|, most
// significant byte first. The caller must guarantee at least two readable bytes.
uint16_t MessageRouter::extractType(uint8_t* message) {
    const uint16_t high = message[0];
    const uint16_t low = message[1];
    return static_cast<uint16_t>((high << 8) | low);
}

int MessageRouter::routeMessage(uint8_t channelId, const shared_ptr<IoBuffer>& message) {
    uint8_t serviceId = mChannelServiceIdMap[channelId];
    if (serviceId == SENTINEL_SERVICE_ID) {
        return STATUS_INVALID_CHANNEL;
    }

    ProtocolEndpointBase* service = mServiceMap[serviceId];
    if (service == NULL) {
        return STATUS_INVALID_SERVICE;
    }

    int status;
    if (service->isPassthrough()) {
        status = service->handleRawMessage(channelId, message);
    } else {
        uint16_t type = extractType((uint8_t*)message->raw());
        status = service->routeMessage(channelId, type, message);
    }
    if (status == STATUS_UNEXPECTED_MESSAGE) {
        status = sendUnexpectedMessage(channelId);
    }

    return status;
}

// Registers |service| under its own id. Returns false, leaving the table
// unchanged, if the id is out of range or the slot is already taken.
bool MessageRouter::registerService(ProtocolEndpointBase* service) {
    // Short-circuit keeps the table index in range.
    if (service->id() < SENTINEL_SERVICE_ID && mServiceMap[service->id()] == NULL) {
        mServiceMap[service->id()] = service;
        return true;
    }
    return false;
}

int MessageRouter::queueOutgoing(uint8_t channelId, void* buf, size_t len) {
    return mChannelManager->queueOutgoing(channelId, false, buf, len);
}

int MessageRouter::queueOutgoingUnencrypted(uint8_t channelId, void* buf, size_t len) {
    return mChannelManager->queueOutgoingUnencrypted(channelId, false, buf, len);
}

// Lets every registered service append its discovery info to |sdr|, in
// ascending service-id order. Empty slots are skipped.
void MessageRouter::populateServiceDiscoveryResponse(ServiceDiscoveryResponse* sdr) {
    for (int slot = 0; slot < MAX_SERVICES; ++slot) {
        ProtocolEndpointBase* endpoint = mServiceMap[slot];
        if (endpoint == NULL) {
            continue;
        }
        endpoint->addDiscoveryInfo(sdr);
    }
}

// Serializes |proto| into |out| as a two-byte type header (written with
// WRITE_BE16) followed by the protobuf bytes. |out| is resized to fit exactly.
// The SerializeToArray() result is not checked.
void MessageRouter::marshallProto(uint16_t type, const google::protobuf::MessageLite& proto,
        IoBuffer* out) {
    const size_t headerLen = sizeof(uint16_t);
    const size_t payloadLen = proto.ByteSize();
    out->resize(headerLen + payloadLen);

    uint8_t* header = (uint8_t*) out->raw();
    uint8_t* payload = header + headerLen;
    WRITE_BE16(header, type);
    proto.SerializeToArray(payload, payloadLen);
}

// Escalates |err| to the Controller service registered under
// CONTROLLER_SERVICE_ID. If no controller is registered, the error is logged
// and dropped instead of dereferencing a NULL table entry.
void MessageRouter::unrecoverableError(MessageStatus err) {
    ProtocolEndpointBase* endpoint = mServiceMap[CONTROLLER_SERVICE_ID];
    if (endpoint == NULL) {
        LOGW("Dropping unrecoverable error %d: no controller registered",
             static_cast<int>(err));
        return;
    }
    // The entry is registered via registerService(ProtocolEndpointBase*), so
    // this is a downcast within that hierarchy.
    Controller* controller = static_cast<Controller*>(endpoint);
    controller->unrecoverableError(err);
}

bool MessageRouter::closeChannel(uint8_t channelId) {
    if (mChannelServiceIdMap[channelId] != SENTINEL_SERVICE_ID) {
        ChannelCloseNotification notification;
        IoBuffer buf;
        marshallProto(MESSAGE_CHANNEL_CLOSE_NOTIFICATION, notification, &buf);
        mChannelManager->queueOutgoing(channelId, true, buf.raw(), buf.size());
        mChannelServiceIdMap[channelId] = SENTINEL_SERVICE_ID;
        return true;
    }
    return false;
}
